# Copyright 2019 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for open_spiel.python.algorithms.adidas_utils.helpers.nonsymmetric.exploitability.

Computing the exploitability of a tsallis-entropy regularized game is more
involved, so we include a derivation here of an example test case using an
asymmetric prisoner's dilemma (see the pd np.array below). Note that the
tsallis-entropy setting assumes non-negative payoffs, so the payoffs have been
shifted up by 3 (the pd array below already includes this shift). We assume
p=1 for the tsallis entropy in this example.

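dist = [(1/3, 2/3), (1/2, 1/2)]  (player 1's and player 2's strategies)

In each table, pt is the player's payoff matrix with rows indexing the player's
own actions, the dist column is the opponent's strategy, and grad = pt @ dist
is the player's expected payoff for each action. For p=1 the best response is
br = grad / sum(grad); payoff(br) = br . grad and payoff(dist) = dist_i . grad,
where dist_i is the player's own strategy.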

-- Player 1 --
pt    dist    grad     br       payoff(br)     payoff(dist)
[2 0] [1/2] = [1] --> [1/3] --> 5/3        --> 5/3
[3 1] [1/2]   [2]     [2/3]

s = sum(grad) = 3

tsallis-entr(br) = s / (p + 1) * (1 - br_1^2 - br_2^2)
                 = 3 / 2 * (1 - 1/9 - 4/9) = 2/3

tsallis-entr(dist) = s / (p + 1) * (1 - dist_1^2 - dist_2^2)
                   = 3 / 2 * (1 - 1/9 - 4/9) = 2/3

u_1(br_1) - u_1(dist) = 5/3 + 2/3 - 5/3 - 2/3 = 0

-- Player 2 --
pt    dist    grad     br       payoff(br)     payoff(dist)
[3 0] [1/3] = [1] --> [1/3] --> 5/3        --> 3/2
[4 1] [2/3]   [2]     [2/3]

s = sum(grad) = 3

tsallis-entr(br) = s / (p + 1) * (1 - br_1^2 - br_2^2)
                 = 3 / 2 * (1 - 1/9 - 4/9) = 2/3

tsallis-entr(dist) = s / (p + 1) * (1 - dist_1^2 - dist_2^2)
                   = 3 / 2 * (1 - 1/4 - 1/4) = 3/4

u_2(br_2) - u_2(dist) = 5/3 + 2/3 - 3/2 - 3/4 = 7/3 - 9/4 = 1/12
"""

from absl import logging  # pylint:disable=unused-import
from absl.testing import absltest
from absl.testing import parameterized

import numpy as np

from open_spiel.python.algorithms.adidas_utils.helpers.nonsymmetric import exploitability


test_seed = 12345


def _entropy(dist):
  """Returns the Shannon entropy of `dist`, using the convention 0 * log(0) = 0.

  The convention lets pure strategies (which contain zeros) be handled directly
  instead of special-casing their entropy as 0.
  """
  probs = np.asarray(dist, dtype=float)
  probs = probs[probs > 0]
  return -np.sum(probs * np.log(probs))


def _qre_exp_i(grad, dist_i, temperature=1.):
  """Computes by hand one player's QRE exploitability (expected test values).

  The entropy-regularized best response to `grad` is
  softmax(grad / temperature), and the regularized utility of a strategy x is
  x . grad + temperature * H(x), where H is the Shannon entropy.

  Args:
    grad: the player's expected payoff for each of its actions, given the
      opponent's strategy.
    dist_i: the player's own mixed strategy.
    temperature: entropy-regularization weight; must be positive.

  Returns:
    Regularized utility of the best response minus that of dist_i.
  """
  grad = np.asarray(grad, dtype=float)
  logits = grad / temperature
  qre_br = np.exp(logits - logits.max())  # shift by the max for stability
  qre_br /= qre_br.sum()
  u_br_minus_dist = (qre_br - dist_i).dot(grad)
  return u_br_minus_dist + temperature * (_entropy(qre_br) - _entropy(dist_i))


# Payoff tables below are indexed [player 1 action, player 2 action].

# asymmetric prisoner's dilemma test case (already shifted to be non-negative)
# pylint:disable=bad-whitespace
pt_r = np.array([[2, 0],
                 [3, 1]])
pt_c = np.array([[3, 4],
                 [0, 1]])
# pylint:enable=bad-whitespace
pd = np.stack((pt_r, pt_c), axis=0)
pd_nash = [np.array([0, 1]), np.array([0, 1])]
pd_non_nash_1 = [np.array([1, 0]), np.array([1, 0])]
pd_non_nash_exp_1 = np.array([1., 1.])
# Tsallis (ATE) exploitability of pd_non_nash_1 with p=1.
pd_non_nash_ate_exp_1 = np.array([9. / 5., 16. / 7.])
pd_non_nash_2 = [np.array([1., 2.]) / 3., np.array([0.5, 0.5])]
pd_non_nash_exp_2 = np.array([1. / 3., 0.5])
# Tsallis (ATE) exploitability of pd_non_nash_2 with p=1, derived in the
# module docstring.
pd_non_nash_ate_exp_2 = np.array([0., 7. / 3. - 9. / 4.])
# Each player's payoff gradient against the opponent's strategy in
# pd_non_nash_2; both evaluate to [1, 2].
pd_non_nash_2_grads = [pd[0].dot(pd_non_nash_2[1]),
                       pd[1].T.dot(pd_non_nash_2[0])]
# QRE exploitability of pd_non_nash_2 at temperature 1.
pd_non_nash_qre_exp_2 = np.array(
    [_qre_exp_i(g, d) for g, d in zip(pd_non_nash_2_grads, pd_non_nash_2)])

# rock-paper-scissors test case (nonsymmetric should work for symmetric as well)
# pylint:disable=bad-whitespace
pt_r = np.array([[0, -1,  1],
                 [1,  0, -1],
                 [-1, 1,  0]])
# pylint:enable=bad-whitespace
pt_r -= pt_r.min()  # shift payoffs to be non-negative
pt_c = pt_r.T
rps = np.stack((pt_r, pt_c), axis=0)
rps_nash = [np.ones(3) / 3., np.ones(3) / 3.]
rps_non_nash_1 = [np.array([1, 0, 0]), np.array([1, 0, 0])]
rps_non_nash_exp_1 = np.array([1., 1.])
rps_non_nash_2 = [np.array([0, 1, 0]), np.array([0, 1, 0])]
rps_non_nash_exp_2 = np.array([1., 1.])
rps_non_nash_3 = [np.array([0, 0, 1]), np.array([0, 0, 1])]
rps_non_nash_exp_3 = np.array([1., 1.])

# two-player game with different numbers of actions
# pylint:disable=bad-whitespace
pt_r = np.array([[2, 2],
                 [3, 0],
                 [0, 3]])
pt_c = np.array([[2, 1, 0],
                 [3, 0, 1]]).T
# pylint:enable=bad-whitespace
rect = [pt_r, pt_c]
rect_unreg_nash = [np.array([0, 1, 0]), np.array([1, 0])]
# Tsallis (ATE) exploitability of rect_unreg_nash with p=1.
rect_unreg_nash_ate_exp = np.array([4. / 5., 0.])
# Payoff gradients against rect_unreg_nash: [2, 3, 0] and [1, 0].
rect_unreg_nash_grads = [rect[0].dot(rect_unreg_nash[1]),
                         rect[1].T.dot(rect_unreg_nash[0])]
# QRE exploitability of rect_unreg_nash at temperature 1.
rect_unreg_nash_qre_exp = np.array(
    [_qre_exp_i(g, d) for g, d in zip(rect_unreg_nash_grads, rect_unreg_nash)])


class ExploitabilityTest(parameterized.TestCase):
  """Tests for the unregularized, ATE (Tsallis) and QRE exploitabilities."""

  @staticmethod
  def _no_op(x):
    """Identity aggregate; keeps the per-player exploitabilities separate."""
    return x

  @staticmethod
  def _random_dists(payoff_tensor, seed, trials):
    """Samples `trials` random mixed-strategy profiles for the given game.

    Args:
      payoff_tensor: per-player payoff tables; payoff_tensor[0].shape gives the
        number of actions of each player.
      seed: seed for np.random.RandomState.
      trials: number of strategy profiles to sample.

    Returns:
      A list of `trials` profiles, each a list with one normalized strategy per
      player.
    """
    random = np.random.RandomState(seed)
    num_strategies = payoff_tensor[0].shape
    pseudo_dists = random.rand(trials, sum(num_strategies))
    # Indices at which each flat sample is cut into per-player blocks.
    split_points = np.cumsum(num_strategies)[:-1]
    profiles = []
    for pseudo_dist in pseudo_dists:
      pseudo_dist_i = np.split(pseudo_dist, split_points)
      profiles.append([pdi / pdi.sum() for pdi in pseudo_dist_i])
    return profiles

  def _check_rand_exploitable(self, payoff_tensor, seed, exp_fn, trials=100):
    """Asserts that exp_fn(dist) > 0 for every randomly sampled profile.

    Args:
      payoff_tensor: payoff tables of the game being sampled.
      seed: seed for the random profile generator.
      exp_fn: maps a strategy profile to a scalar exploitability.
      trials: number of random profiles to check.
    """
    dists = self._random_dists(payoff_tensor, seed, trials)
    exploitable = [exp_fn(dist) > 0. for dist in dists]
    perc = 100 * np.mean(exploitable)
    logging.info('rand strat exploitable rate out of %d is %f', trials, perc)
    self.assertEqual(perc, 100., 'found rand strat that was nash')

  @parameterized.named_parameters(
      ('PD_nash', pd, pd_nash),
      ('RPS_nash', rps, rps_nash),
      ('RECT_nash', rect, rect_unreg_nash),
      )
  def test_unreg_exploitability_of_nash(self, payoff_tensor, nash):
    exp = exploitability.unreg_exploitability(nash, payoff_tensor, np.max)
    # Tolerance rather than exact equality: strategies such as the uniform 1/3
    # profile are not exactly representable in floating point.
    self.assertAlmostEqual(exp, 0., places=10,
                           msg='nash should have zero exploitability')

  @parameterized.named_parameters(
      ('PD_non_nash_1', pd, pd_non_nash_1, pd_non_nash_exp_1),
      ('PD_non_nash_2', pd, pd_non_nash_2, pd_non_nash_exp_2),
      ('RPS_non_nash_1', rps, rps_non_nash_1, rps_non_nash_exp_1),
      ('RPS_non_nash_2', rps, rps_non_nash_2, rps_non_nash_exp_2),
      ('RPS_non_nash_3', rps, rps_non_nash_3, rps_non_nash_exp_3),
      )
  def test_unreg_exploitability_of_non_nash(self, payoff_tensor, dist, exp):
    exp_pred = exploitability.unreg_exploitability(dist, payoff_tensor,
                                                   self._no_op)
    msg = 'exploitability mismatch: pred={}, true={}'.format(exp_pred, exp)
    self.assertTrue(np.allclose(exp_pred, exp), msg)

  @parameterized.named_parameters(
      ('PD_rand', pd, test_seed),
      ('RPS_rand', rps, test_seed),
      ('RECT_rand', rect, test_seed),
      )
  def test_unreg_exploitability_of_rand(self, payoff_tensor, seed=None):
    self._check_rand_exploitable(
        payoff_tensor, seed,
        lambda dist: exploitability.unreg_exploitability(
            dist, payoff_tensor, np.max))

  @parameterized.named_parameters(
      ('RPS_nash_p=0', rps, rps_nash, 0.),
      ('RPS_nash_p=0.1', rps, rps_nash, 0.1),
      ('RPS_nash_p=1', rps, rps_nash, 1.),
      )
  def test_ate_exploitability_of_nash(self, payoff_tensor, nash, p):
    exp = exploitability.ate_exploitability(nash, payoff_tensor, p, np.max)
    # Same 1e-10 floating-point tolerance as the QRE nash test below.
    self.assertLessEqual(exp, 1e-10,
                         'uniform nash should have zero exploitability')

  @parameterized.named_parameters(
      ('PD_non_nash_1_p=0', pd, 0., pd_non_nash_1, pd_non_nash_exp_1),
      ('PD_non_nash_1_p=1', pd, 1., pd_non_nash_1, pd_non_nash_ate_exp_1),
      ('PD_non_nash_2_p=1', pd, 1., pd_non_nash_2, pd_non_nash_ate_exp_2),
      ('RECT_non_nash_p=1', rect, 1., rect_unreg_nash, rect_unreg_nash_ate_exp),
      )
  def test_ate_exploitability_of_non_nash(self, payoff_tensor, p, dist, exp):
    exp_pred = exploitability.ate_exploitability(dist, payoff_tensor, p,
                                                 self._no_op)
    msg = 'exploitability mismatch: pred={}, true={}'.format(exp_pred, exp)
    self.assertTrue(np.allclose(exp_pred, exp), msg=msg)

  @parameterized.named_parameters(
      ('RPS_rand_p=0', rps, 0., test_seed),
      ('RPS_rand_p=0.1', rps, 0.1, test_seed),
      ('RPS_rand_p=1', rps, 1., test_seed),
      ('RECT_rand_p=1', rect, 1., test_seed),
      )
  def test_ate_exploitability_of_rand(self, payoff_tensor, p, seed=None):
    self._check_rand_exploitable(
        payoff_tensor, seed,
        lambda dist: exploitability.ate_exploitability(
            dist, payoff_tensor, p, np.max))

  @parameterized.named_parameters(
      ('RPS_nash_tau=0', rps, rps_nash, 0.),
      ('RPS_nash_tau=0.1', rps, rps_nash, 0.1),
      ('RPS_nash_tau=1', rps, rps_nash, 1.),
      )
  def test_qre_exploitability_of_nash(self, payoff_tensor, nash, temperature):
    exp = exploitability.qre_exploitability(nash, payoff_tensor, temperature,
                                            np.max)
    self.assertLessEqual(exp, 1e-10,
                         'uniform nash should have zero exploitability')

  @parameterized.named_parameters(
      ('PD_non_nash_1_tau=0', pd, 0., pd_non_nash_1, pd_non_nash_exp_1),
      ('PD_non_nash_2_tau=1', pd, 1., pd_non_nash_2, pd_non_nash_qre_exp_2),
      ('RECT_non_nash_tau=1', rect, 1., rect_unreg_nash,
       rect_unreg_nash_qre_exp),
      )
  def test_qre_exploitability_of_non_nash(self, payoff_tensor, temperature,
                                          dist, exp):
    exp_pred = exploitability.qre_exploitability(dist, payoff_tensor,
                                                 temperature, self._no_op)
    msg = 'exploitability mismatch: pred={}, true={}'.format(exp_pred, exp)
    self.assertTrue(np.allclose(exp_pred, exp), msg=msg)

  @parameterized.named_parameters(
      ('RPS_rand_tau=0', rps, 0., test_seed),
      ('RPS_rand_tau=0.1', rps, 0.1, test_seed),
      ('RPS_rand_tau=1', rps, 1., test_seed),
      ('RECT_rand_tau=1', rect, 1., test_seed),
      )
  def test_qre_exploitability_of_rand(self, payoff_tensor, temperature,
                                      seed=None):
    self._check_rand_exploitable(
        payoff_tensor, seed,
        lambda dist: exploitability.qre_exploitability(
            dist, payoff_tensor, temperature, np.max))


if __name__ == '__main__':
  # Run every test case defined in this module through absl's test runner.
  absltest.main()
